load("//tensorflow:tensorflow.bzl", "if_google", "if_oss", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load(
    "//tensorflow/core/platform:build_config_root.bzl",
    "if_llvm_aarch32_available",
    "if_llvm_aarch64_available",
    "if_llvm_hexagon_available",
    "if_llvm_powerpc_available",
    "if_llvm_system_z_available",
    "if_llvm_x86_available",
)
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Package-wide defaults: targets are private unless they declare their own
# visibility, and the package license kind is "notice".
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = ["//visibility:private"],
    licenses = ["notice"],
)

# Don't depend on this directly; this is only used for the benchmark test
# generated by tf_library.
cc_library(
    name = "tf_library_test_main",
    testonly = True,
    visibility = ["//visibility:public"],
    # Forwards to a test_main implementation: the core one via if_oss, the
    # build_config one via if_google (presumably only one branch is non-empty
    # in any given build).
    deps = if_oss(["//tensorflow/core:test_main"]) + if_google([
        "//tensorflow/core/platform/default/build_config:test_main",
    ]),
)

# Exposes quantize.h to every package as a plain file. The same header is also
# listed in the hdrs of :tfcompile_lib.
filegroup(
    name = "quantize_header",
    srcs = ["quantize.h"],
    visibility = ["//visibility:public"],
)

# Library built from thunk_proto_execution_deserializer.{h,cc}. It depends on
# the XLA CPU runtime thunk protos and the CPU executable proto, and is linked
# into :tfcompile_lib.
# NOTE(review): presumably turns serialized CPU thunks into code for the
# generated AOT sources — confirm against the header.
cc_library(
    name = "thunk_proto_execution_deserializer",
    srcs = ["thunk_proto_execution_deserializer.cc"],
    hdrs = ["thunk_proto_execution_deserializer.h"],
    visibility = ["//visibility:public"],
    deps = [
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@local_xla//xla:debug_options_flags",
        "@local_xla//xla:shape_util",
        "@local_xla//xla:status_macros",
        "@local_xla//xla:util",
        "@local_xla//xla:xla_data_proto_cc",
        "@local_xla//xla/backends/cpu/runtime:convolution_lib",
        "@local_xla//xla/backends/cpu/runtime:dot_lib",
        "@local_xla//xla/backends/cpu/runtime:thunk_proto_cc",
        "@local_xla//xla/service/cpu:executable_proto_cc",
        "@local_xla//xla/tsl/platform:statusor",
    ],
)

# Core tfcompile library, built from codegen.cc, compile.cc and flags.cc.
# It is visible to //tensorflow and //tensorflow/python. Within this package,
# :tfcompile_main and :codegen_test also build on it.
cc_library(
    name = "tfcompile_lib",
    srcs = [
        "codegen.cc",
        "compile.cc",
        "flags.cc",
    ],
    hdrs = [
        "codegen.h",
        "compile.h",
        "flags.h",
        "quantize.h",
    ],
    compatible_with = [],
    # Adds one TF_LLVM_<ARCH>_AVAILABLE=1 define per LLVM backend available in
    # this build. The same if_llvm_*_available conditions decide which backends
    # :llvm_targets links in, presumably so the sources can guard
    # backend-specific code on these macros. The SystemZ backend is exposed as
    # TF_LLVM_S390X_AVAILABLE.
    defines = if_llvm_aarch32_available(["TF_LLVM_AARCH32_AVAILABLE=1"]) + if_llvm_aarch64_available([
        "TF_LLVM_AARCH64_AVAILABLE=1",
    ]) + if_llvm_hexagon_available([
        "TF_LLVM_HEXAGON_AVAILABLE=1",
    ]) + if_llvm_powerpc_available([
        "TF_LLVM_POWERPC_AVAILABLE=1",
    ]) + if_llvm_system_z_available([
        "TF_LLVM_S390X_AVAILABLE=1",
    ]) + if_llvm_x86_available([
        "TF_LLVM_X86_AVAILABLE=1",
    ]),
    visibility = [
        "//tensorflow:__pkg__",
        "//tensorflow/python:__pkg__",
    ],
    deps = [
        ":aot_only_var_handle_op",
        ":embedded_constant_buffers",
        ":embedded_protocol_buffers",
        # Linked for the LLVM backends it pulls in. It is marked "fixdeps: keep",
        # presumably because no source includes a header from it.
        ":llvm_targets",  # fixdeps: keep
        ":thunk_proto_execution_deserializer",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:allocator",
        "//tensorflow/compiler/tf2xla:encoded_buffer_allocation_info",
        "//tensorflow/compiler/tf2xla:mlir_tf2xla",  # fixdeps: keep
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:regexp",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
        "@local_xla//xla:debug_options_flags",
        "@local_xla//xla:shape_util",
        "@local_xla//xla:status_macros",
        "@local_xla//xla:util",
        "@local_xla//xla:xla_data_proto_cc",
        "@local_xla//xla/backends/cpu:buffer_allocation_info",
        "@local_xla//xla/backends/cpu:buffer_allocation_info_util",
        "@local_xla//xla/backends/cpu/codegen:symbol_name_util",
        "@local_xla//xla/backends/cpu/runtime:thunk_proto_cc",
        "@local_xla//xla/backends/cpu/runtime:thunk_proto_serdes",
        "@local_xla//xla/client:client_library",
        "@local_xla//xla/client:compile_only_client",
        "@local_xla//xla/hlo/builder:xla_computation",
        "@local_xla//xla/service:compiler",
        "@local_xla//xla/service/cpu:cpu_aot_compilation_result",
        "@local_xla//xla/service/cpu:cpu_compiler",
        "@local_xla//xla/service/cpu:cpu_executable",
        "@local_xla//xla/stream_executor:platform",
        "@local_xla//xla/stream_executor:platform_manager",
    ],
)

# Unit test (codegen_test.cc) for code in :tfcompile_lib.
tf_cc_test(
    name = "codegen_test",
    srcs = ["codegen_test.cc"],
    # The X86 define and the X86CodeGen dependency below are both gated on the
    # LLVM X86 backend being available. Presumably this lets the test guard
    # x86-specific cases with TF_LLVM_X86_AVAILABLE.
    defines = if_llvm_x86_available(["TF_LLVM_X86_AVAILABLE=1"]),
    deps = [
        ":tfcompile_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",  # fixdeps: keep
        "@local_xla//xla:shape_util",
        "@local_xla//xla/service/cpu:cpu_aot_compilation_result",
    ] + if_llvm_x86_available([
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ]),
)

# The tfcompile command-line tool. All of its code lives in :tfcompile_main;
# this target only links that library into a publicly visible binary.
tf_cc_binary(
    name = "tfcompile",
    visibility = ["//visibility:public"],
    deps = [":tfcompile_main"],
)

# Source-less aggregate that links in the LLVM assembly parser and code
# generator for every target architecture available in this build. The
# if_llvm_*_available conditions mirror the TF_LLVM_*_AVAILABLE defines on
# :tfcompile_lib.
cc_library(
    name = "llvm_targets",
    visibility = ["//tensorflow/python:__pkg__"],
    deps = if_llvm_aarch32_available([
        "@llvm-project//llvm:ARMAsmParser",  # fixdeps: keep
        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
    ]) + if_llvm_aarch64_available([
        "@llvm-project//llvm:AArch64AsmParser",  # fixdeps: keep
        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
    ]) + if_llvm_hexagon_available([
        "@llvm-project//llvm:HexagonAsmParser",  # fixdeps: keep
        "@llvm-project//llvm:HexagonCodeGen",  # fixdeps: keep
    ]) + if_llvm_powerpc_available([
        "@llvm-project//llvm:PowerPCAsmParser",  # fixdeps: keep
        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
    ]) + if_llvm_system_z_available([
        "@llvm-project//llvm:SystemZAsmParser",  # fixdeps: keep
        "@llvm-project//llvm:SystemZCodeGen",  # fixdeps: keep
    ]) + if_llvm_x86_available([
        "@llvm-project//llvm:X86AsmParser",  # fixdeps: keep
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ]),
)

# Entry point for the :tfcompile binary, built from tfcompile_main.cc on top of
# :tfcompile_lib. Public visibility lets any package depend on it directly.
cc_library(
    name = "tfcompile_main",
    srcs = ["tfcompile_main.cc"],
    compatible_with = [],
    visibility = ["//visibility:public"],
    deps = [
        ":tfcompile_lib",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@local_xla//xla:debug_options_flags",
        "@local_xla//xla/tsl/platform:status",
    ],
)

# NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that
# tfcompile.bzl correctly handles usage from outside of the package that it is
# defined in.

# A simple test of tf_library from a text protobuf, to enable benchmark_test.
# This test uses an incomplete graph containing a node that is not defined.
# Compilation still works because the undefined node is a feed node.
tf_library(
    name = "test_graph_tfadd",
    testonly = True,
    config = "test_graph_tfadd.config.pbtxt",
    cpp_class = "AddComp",
    graph = "test_graph_tfadd.pbtxt",
    # Variant compiled with mlir_components = "None"; compare
    # :test_graph_tfadd_mlir_bridge, which uses "Bridge".
    mlir_components = "None",
    # "manual" keeps wildcard patterns such as //...:all from building it.
    tags = ["manual"],
)

# Same graph and config as :test_graph_tfadd, compiled with
# mlir_components = "Bridge" instead of "None".
tf_library(
    name = "test_graph_tfadd_mlir_bridge",
    testonly = True,
    config = "test_graph_tfadd.config.pbtxt",
    cpp_class = "AddComp",
    graph = "test_graph_tfadd.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the node with the unknown op is not needed
# for the fetches.
tf_library(
    name = "test_graph_tfunknownop",
    testonly = True,
    config = "test_graph_tfunknownop.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    # Variant compiled with mlir_components = "None"; compare
    # :test_graph_tfunknownop_mlir_bridge, which uses "Bridge".
    mlir_components = "None",
    tags = ["manual"],
)

# Same graph and config as :test_graph_tfunknownop, compiled with
# mlir_components = "Bridge" instead of "None".
tf_library(
    name = "test_graph_tfunknownop_mlir_bridge",
    testonly = True,
    config = "test_graph_tfunknownop.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the node with the unknown op is only used as
# an input of a feed node.
tf_library(
    name = "test_graph_tfunknownop2",
    testonly = True,
    # Reuses the test_graph_tfunknownop graph with a different config file.
    config = "test_graph_tfunknownop2.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "None",
    tags = ["manual"],
)

# Same graph and config as :test_graph_tfunknownop2, compiled with
# mlir_components = "Bridge" instead of "None".
tf_library(
    name = "test_graph_tfunknownop2_mlir_bridge",
    testonly = True,
    config = "test_graph_tfunknownop2.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the node with the unknown op is a feed node.
tf_library(
    name = "test_graph_tfunknownop3",
    testonly = True,
    # Reuses the test_graph_tfunknownop graph with a different config file.
    config = "test_graph_tfunknownop3.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "None",
    tags = ["manual"],
)

# Same graph and config as :test_graph_tfunknownop3, compiled with
# mlir_components = "Bridge" instead of "None".
tf_library(
    name = "test_graph_tfunknownop3_mlir_bridge",
    testonly = True,
    config = "test_graph_tfunknownop3.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)

# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tf_library macro (tfcompile.bzl).
cc_library(
    name = "benchmark",
    srcs = ["benchmark.cc"],
    hdrs = ["benchmark.h"],
    # Public because generated *_benchmark targets may live in other packages
    # (e.g. the tests subdirectory).
    visibility = ["//visibility:public"],
    deps = [
        # The purpose of the benchmark library is to support building an aot
        # binary with minimal dependencies, to demonstrate small binary sizes.
        #
        # KEEP THE DEPENDENCIES MINIMAL.
        "//tensorflow/core:framework_lite",
    ],
)

# Empty library (no srcs, hdrs or deps), tagged manual.
# NOTE(review): presumably a hook where Android benchmark builds can add extra
# dependencies — confirm against tfcompile.bzl before relying on it.
cc_library(
    name = "benchmark_extra_android",
    tags = [
        "manual",
    ],
    visibility = ["//visibility:public"],
)

# Library built from embedded_protocol_buffers.{h,cc} and linked into
# :tfcompile_lib. It depends on LLVM Core/MC/Target/TargetParser.
# NOTE(review): presumably used to embed serialized protos into generated
# object code — confirm against the header.
cc_library(
    name = "embedded_protocol_buffers",
    srcs = ["embedded_protocol_buffers.cc"],
    hdrs = ["embedded_protocol_buffers.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:MC",
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:TargetParser",
        "@local_xla//xla:util",
        "@local_xla//xla/service/llvm_ir:llvm_type_conversion_util",
    ],
)

# Library built from embedded_constant_buffers.{h,cc} and linked into
# :tfcompile_lib. It is tested by :embedded_constant_buffers_test. It uses LLVM
# plus XLA's CPU alignment helper.
# NOTE(review): presumably emits constant buffers as data in generated object
# code — confirm against the header.
cc_library(
    name = "embedded_constant_buffers",
    srcs = ["embedded_constant_buffers.cc"],
    hdrs = ["embedded_constant_buffers.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:MC",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:TargetParser",
        "@local_xla//xla:util",
        "@local_xla//xla/backends/cpu:alignment",
        "@local_xla//xla/service/llvm_ir:llvm_type_conversion_util",
    ],
)

# Unit test for :embedded_constant_buffers. It links :llvm_targets (kept via
# fixdeps) so the available LLVM backends are present at test time.
# NOTE(review): this uses "@local_tsl//tsl/platform:statusor", whereas the
# other targets here use "@local_xla//xla/tsl/platform:statusor". Consider
# migrating together with the corresponding #include in the test source.
tf_cc_test(
    name = "embedded_constant_buffers_test",
    srcs = ["embedded_constant_buffers_test.cc"],
    deps = [
        ":embedded_constant_buffers",
        ":llvm_targets",  # fixdeps: keep
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
        "@local_tsl//tsl/platform:statusor",
        "@local_xla//xla/service/cpu:test_header_helper",
    ],
)

# Library built from aot_only_var_handle_op.{h,cc}. Besides targets in this
# package (e.g. :tfcompile_lib), only //tensorflow/compiler/tf2xla may use it.
cc_library(
    name = "aot_only_var_handle_op",
    srcs = ["aot_only_var_handle_op.cc"],
    hdrs = ["aot_only_var_handle_op.h"],
    visibility = [
        "//tensorflow/compiler/tf2xla:__pkg__",
    ],
    deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_context",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/core:framework",
    ],
    # alwayslink keeps this library's objects in the final link even when no
    # symbol from them is referenced directly. Presumably this is needed
    # because the op/kernel registers itself via static initializers.
    alwayslink = 1,
)

# Test for the :benchmark utility library, linked against the
# :test_graph_tfadd AOT library. It is tagged manual and listed explicitly in
# :all_tests.
tf_cc_test(
    name = "benchmark_test",
    srcs = ["benchmark_test.cc"],
    tags = ["manual"],
    deps = [
        ":benchmark",
        ":test_graph_tfadd",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Aggregates this package's tests plus //tensorflow/compiler/aot/tests:all_tests.
# The :test_graph_*_test targets are not declared in this file; presumably
# tf_library generates them (gen_test) — confirm in tfcompile.bzl.
test_suite(
    name = "all_tests",
    tags = ["manual"],
    tests = [
        ":benchmark_test",
        ":codegen_test",
        ":test_graph_tfadd_mlir_bridge_test",
        ":test_graph_tfadd_test",
        ":test_graph_tfunknownop2_mlir_bridge_test",
        ":test_graph_tfunknownop2_test",
        ":test_graph_tfunknownop3_mlir_bridge_test",
        ":test_graph_tfunknownop3_test",
        ":test_graph_tfunknownop_mlir_bridge_test",
        ":test_graph_tfunknownop_test",
        "//tensorflow/compiler/aot/tests:all_tests",
    ],
)

# Source templates consumed by tf_library-generated targets. They are exported
# so the macro can reference them from any package. if_oss presumably selects
# test.cc for the open-source build and test_google.cc otherwise.
exports_files([
    "benchmark_main.template",  # used by tf_library(...,gen_benchmark=True)
    if_oss("test.cc", "test_google.cc"),  # used by tf_library(...,gen_test=True)
])
